# Codes written by Maria E. Frediani (NCAR/RAL; frediani@ucar.edu)

import os
import numpy as np
from netCDF4 import Dataset

# --------------------------------------------------------------------
# --------------------------------------------------------------------


def pyproj_wrf2wgs(wf, **kwargs):
    '''
    Convert WRF grid coordinates to the native WRF map projection and to
    EPSG:3857 (Web Mercator, the projection used by web map tiles) with pyproj.

    Parameters
    ----------
    wf : path of the WRF geogrid file to open with netCDF4.
    **kwargs : accepted for call compatibility; not used.

    Returns
    -------
    tuple
        (src, dst, wrfext, prjext, (lonc, latc), wrfxy, prjxy, (DX, DY)) where
        src    : pyproj CRS built from the WRF_Grid projection string
        dst    : pyproj CRS for EPSG:3857
        wrfext : corner points projected to the native WRF projection
        prjext : corner points projected to EPSG:3857
        lonc, latc : last four entries of the corner_lons / corner_lats attributes
        wrfxy  : XLONG_M / XLAT_M projected to the native WRF projection
        prjxy  : XLONG_M / XLAT_M projected to EPSG:3857
        DX, DY : the DX and DY global attributes of the file
    '''

    # use the pyproj.transformer.Transformer and pyproj.crs.CRS
    # in place of the deprecated pyproj.Proj and pyproj.transformer.transform()

    from pyproj import CRS, Transformer, Proj
    from wrf_grid import WRF_Grid
    # WRF_Grid was adapted from https://github.com/NCAR/wrf_hydro_arcgis_preprocessor/blob/master/wrf_hydro_functions.py

    # Read everything that is needed inside the context manager so the netCDF
    # handle is always released, even if one of the reads raises.
    with Dataset(wf, 'r') as ds:  # geogrid file here!
        # Build the grid description once and reuse its projection string.
        proj_string = WRF_Grid(ds).Projection_String

        latc = ds.corner_lats[-4:]  # [4x XLAT, 4x XLAT_U, 4x XLAT_V, 4x XLAT_C]
        lonc = ds.corner_lons[-4:]

        lat = np.squeeze(ds.variables['XLAT_M'][:])
        lon = np.squeeze(ds.variables['XLONG_M'][:])

        grid_spacing = (ds.DX, ds.DY)

    src = CRS(proj_string)
    print(proj_string)
    dst = CRS.from_epsg(3857)

    # 1 Convert from lon/lat to native map projection x,y coordinates
    # (preserve_units=False asks pyproj to work in meters)
    wrf2meters = Proj(src, preserve_units=False)
    wrfext = wrf2meters(lonc, latc)  # wrf2meters(x, y, inverse=True)
    wrfxy = wrf2meters(lon, lat)

    # 2 Convert from lon/lat on the WRF geodetic CRS to the destination CRS
    wrf2prj = Transformer.from_crs(crs_from=src.geodetic_crs,
                                   crs_to=dst,
                                   always_xy=True)

    # lowleft, upleft, upright, lowright
    prjext = wrf2prj.transform(lonc, latc)
    prjxy = wrf2prj.transform(lon, lat)

    return src, dst, wrfext, prjext, (lonc, latc), wrfxy, prjxy, grid_spacing


# --------------------------------------------------------------------
# --------------------------------------------------------------------


def cartopy_defwrfproj(wf, **kwargs):
    '''
    Define the WRF projection as a cartopy Lambert Conformal CRS.

    The projection parameters are read from the global attributes STAND_LON,
    MOAD_CEN_LAT, TRUELAT1 and TRUELAT2 of the netCDF file `wf`. A spherical
    globe with a 6370000 radius is used. The function always builds a
    LambertConformal projection; it does not inspect which map projection
    the file declares.

    Parameters
    ----------
    wf : path of the WRF netCDF file to open with netCDF4.
    **kwargs : accepted for call compatibility; not used.

    Returns
    -------
    cartopy.crs.LambertConformal
    '''

    import cartopy.crs

    # Copy the attributes out inside the context manager so the file handle
    # is closed before the projection is built.
    with Dataset(wf, 'r') as ds:
        stand_lon = ds.STAND_LON
        cen_lat = ds.MOAD_CEN_LAT
        true_lats = (ds.TRUELAT1, ds.TRUELAT2)

    globe = cartopy.crs.Globe(ellipse=None,
                              semimajor_axis=6370000,
                              semiminor_axis=6370000,
                              nadgrids="@null")

    # Set cutoff to -30 for NH, +30.0 for SH.
    cutoff = -30.0 if cen_lat >= 0 else 30.0

    return cartopy.crs.LambertConformal(central_longitude=stand_lon,
                                        central_latitude=cen_lat,
                                        standard_parallels=true_lats,
                                        globe=globe,
                                        cutoff=cutoff)


# --------------------------------------------------------------------
# --------------------------------------------------------------------


def read_wrffr(ff, **kwargs):
    '''
    Read one variable from a netCDF file and return it as a numpy array.

    Parameters
    ----------
    ff : path of the netCDF file to open.
    var : (keyword) name of the variable to read; defaults to 'SP_SUM_DEP'.

    Returns
    -------
    numpy.ndarray
        The full variable with singleton dimensions squeezed out.
    '''

    varname = kwargs.get('var', 'SP_SUM_DEP')

    with Dataset(ff, 'r') as nc:
        values = nc.variables[varname][:]
        squeezed = np.squeeze(values)

    return np.array(squeezed)


# --------------------------------------------------------------------
# --------------------------------------------------------------------


def mplext_outer(ext):
    '''
    Return the outer bounding box of a projected extent, ordered for use as a
    matplotlib extent.

    Parameters
    ----------
    ext : pair whose first item holds the x values and whose second item
          holds the y values.

    Returns
    -------
    list
        [xmin, xmax, ymin, ymax]
    '''

    xs, ys = ext[0], ext[1]
    return [min(xs), max(xs), min(ys), max(ys)]


# --------------------------------------------------------------------
# --------------------------------------------------------------------


def gjson_feature_geometry(filename):
    '''
    Load a JSON file and return the entries under its top-level 'features' key.

    The file name is printed before the file is read.

    Parameters
    ----------
    filename : path of the JSON file.

    Returns
    -------
    list
        A new list with the items found under 'features'.
    '''

    import json

    print(filename)

    with open(filename) as handle:
        features = json.load(handle)['features']

    return list(features)


# --------------------------------------------------------------------
# --------------------------------------------------------------------


def gjson2poly(ft, *, prjtrans, **kwargs):
    '''
    Project the coordinates of a GeoJSON-like mapping and return them as a
    geometry dict.

    Parameters
    ----------
    ft : mapping with either a 'geometry' entry (its 'coordinates' are used)
         or a 'features' entry (the 'coordinates' of every feature geometry
         are collected).
    prjtrans : callable applied to each unpacked coordinate, i.e.
               prjtrans(*point), returning the projected point.
    **kwargs : accepted for call compatibility; not used.

    Returns
    -------
    dict
        {'type': 'Polygon', 'coordinates': [...]} when the flattened
        coordinates hold a single item, otherwise
        {'type': 'MultiPolygon', 'coordinates': [...]}.

    Raises
    ------
    ValueError
        If `ft` has neither a 'geometry' nor a 'features' entry.
    '''

    if 'geometry' in ft:
        poly = ft['geometry']['coordinates']
    elif 'features' in ft:
        poly = [feat['geometry']['coordinates'] for feat in ft['features']]
    else:
        # Without this guard `poly` would be unbound and the function would
        # fail later with an opaque UnboundLocalError.
        raise ValueError("expected a mapping with a 'geometry' or 'features' entry")

    # Strip redundant single-item nesting levels.
    poly = flatlist(poly)

    if len(poly) == 1:
        poly = poly[0]
        mappoly = dict(type='Polygon', coordinates=[[list(prjtrans(*i)) for i in poly]])
    else:
        # Only the first ring (mp[0]) of each item is projected.
        # NOTE(review): all projected rings end up under one outer list;
        # confirm this nesting is what the consumers of the result expect.
        mappoly = dict(type='MultiPolygon',
                       coordinates=[[[prjtrans(*i) for i in mp[0]] for mp in poly]])

    return mappoly


# --------------------------------------------------------------------
# --------------------------------------------------------------------


def flatlist(a):
    '''
    Strip redundant outer nesting from a nested sequence.

    While the sequence holds exactly one item and that item itself holds
    exactly one item, the outer level is dropped. The first sequence that
    does not match this pattern is returned.
    '''

    while True:
        if len(a) != 1 or len(a[0]) != 1:
            return a
        a = a[0]


# --------------------------------------------------------------------
# --------------------------------------------------------------------


def load_baseimage_any(basetype, *, bbox):
    '''
    Return a USGS base-map image for a bounding box, using a PNG on disk as cache.

    The cache file name is built from `basetype` and the `bbox` values rounded
    to two decimals, and is a relative path. When that file exists it is
    opened and returned. Otherwise the image is requested from the WMS
    endpoint mapped to `basetype` (layer '0', EPSG:3857, 4096 x 4096,
    transparent PNG), saved under the cache name and returned.

    Parameters
    ----------
    basetype : service name; must be one of the keys of the URL table below
               whenever no cached file exists (a KeyError is raised otherwise).
    bbox : (keyword-only) iterable of bounding-box values passed to the WMS
           request.

    Returns
    -------
    PIL image object
    '''

    from PIL import Image
    from owslib.wms import WebMapService

    bbox_tag = '_'.join([str(np.round(v, decimals=2)) for v in bbox])
    fimg = 'usgs' + basetype + '_epsg3857_bbox_' + bbox_tag + '.png'

    # Cached copy available: no network access needed.
    if os.path.isfile(fimg):
        return Image.open(fimg)

    basemap_url = 'https://basemap.nationalmap.gov/arcgis/services/{}/MapServer/WmsServer?'.format(
        basetype)
    imagery_url = 'https://services.nationalmap.gov/arcgis/services/{}/ImageServer/WMSServer?'.format(
        basetype)

    url_dict = {
        'USGSNAIPPlus': imagery_url,
        'USGSNAIPImagery': imagery_url,
        'USGSImageryOnly': basemap_url,
        'USGSHydroCached': basemap_url,
        'USGSTopo': basemap_url,
        'USGSShadedReliefOnly': basemap_url,
    }

    map_url = url_dict[basetype]
    print(map_url)

    wms = WebMapService(map_url, version='1.3.0')
    response = wms.getmap(layers='0',
                          srs='EPSG:3857',
                          bbox=bbox,
                          format='image/png',
                          size=(4096, 4096),
                          transparent=True)

    img = Image.open(response)
    img.save(fimg)  # keep a copy so the next call can skip the request
    return img


# --------------------------------------------------------
# --------------------------------------------------------


def refine_dom(lon, lat, rr):
    '''
    Linearly interpolate 2-D lon/lat arrays onto a finer index grid.

    Parameters
    ----------
    lon, lat : 2-D arrays of the same shape, sampled at integer index positions.
    rr : refinement ratio, 0 < rr < 1; the step of the refined index grid.

    Returns
    -------
    tuple
        (xxlon, xxlat), the fields interpolated with scipy griddata
        (method='linear') at index positions 0, rr, 2*rr, ... along each axis.
    '''

    from scipy.interpolate import griddata

    nx, ny = lon.shape

    # Index positions of the original samples.
    src_i, src_j = np.meshgrid(np.arange(0, nx, 1),
                               np.arange(0, ny, 1),
                               indexing='ij')
    nodes = np.column_stack((src_i.ravel(), src_j.ravel()))

    # Fractional index positions of the refined grid.
    fine_i, fine_j = np.meshgrid(np.arange(0, nx - 1 + rr, rr),
                                 np.arange(0, ny - 1 + rr, rr),
                                 indexing='ij')
    targets = (fine_i, fine_j)

    xxlon, xxlat = (griddata(nodes, field.flatten(), targets, method='linear')
                    for field in (lon, lat))

    return (xxlon, xxlat)


# --------------------------------------------------------------------
# --------------------------------------------------------------------


def ext2bbox(left, right, bottom, top):
    '''
    Reorder an extent given as (left, right, bottom, top) into a bounding box
    tuple (left, bottom, right, top).
    '''
    bbox = (left, bottom, right, top)
    return bbox
